package websocket

import (
	"sync"
	"time"

	"github.com/gogf/gf/frame/g"
	"github.com/gogf/gf/net/ghttp"
	"go.uber.org/ratelimit"
)

// server is the process-wide websocket state: registered message handlers,
// live sessions and users, and the connection open/close hooks.
type server struct {
	// Handlers registered through AddHandler, keyed by message opcode.
	msgHandlers map[MessageOpcode][]*MessageHandler

	// Live sessions keyed by session Id; guarded by sessionLocker.
	sessionLocker *sync.RWMutex
	sessions      map[uint64]*Session

	// Users keyed by user Id; guarded by userLocker.
	userLocker *sync.RWMutex
	users      map[string]*User

	// Connection hooks: openHandle runs first once a connection is opened,
	// closeHandle runs last once it is closed.
	openHandle  MessageHandleFunc
	closeHandle MessageHandleFunc
}

// Per-connection limits applied by HandleWebsocket.
const (
	// maxMessageSize caps the size of a single incoming frame (4 KiB).
	maxMessageSize = 4 << 10
	// maxMessageRate is how many messages per second one connection may send.
	maxMessageRate = 20
	// pongWait is how long a connection may stay silent before its read
	// deadline expires.
	pongWait = time.Second * 45
)

// Singleton state behind Server: wsServer holds the lazily created instance
// and wsServerLocker serializes its creation.
var (
	wsServer       *server
	wsServerLocker = new(sync.Mutex)
)

// Server returns the process-wide websocket server, creating it on first use.
//
// On creation it wires the Controller's Open/Close hooks, registers the
// built-in login handlers (OpCAuthLogin, OpCAuthAnonymous) at
// SessionLevelConnect, and starts the background sending loop once.
//
// The lock is taken on every call. An unsynchronized fast-path read of
// wsServer would race with its assignment under the Go memory model, and could
// observe a server whose handlers were still being registered. The instance is
// only published to wsServer once it is fully set up.
func Server() *server {
	wsServerLocker.Lock()
	defer wsServerLocker.Unlock()

	if wsServer != nil {
		return wsServer
	}

	cC := new(Controller)
	srv := &server{
		sessionLocker: &sync.RWMutex{},
		sessions:      make(map[uint64]*Session),
		userLocker:    &sync.RWMutex{},
		users:         make(map[string]*User),
		msgHandlers:   make(map[MessageOpcode][]*MessageHandler),
		openHandle:    cC.Open,
		closeHandle:   cC.Close,
	}
	srv.AddHandler(SessionLevelConnect, OpCAuthLogin, cC.Auth)          // authenticated login
	srv.AddHandler(SessionLevelConnect, OpCAuthAnonymous, cC.Anonymous) // anonymous login

	// Publish only after the handlers are registered.
	wsServer = srv
	go wsServer.loopSending()

	return wsServer
}

// HandleWebsocket upgrades r to a websocket connection and serves it until the
// peer disconnects or a read fails (including the read deadline expiring).
//
// Connection lifecycle:
//   - the server's open hook runs, then the open message is dispatched to the
//     registered handlers via ReceiveMessage;
//   - a frame whose payload is exactly `ping` is answered with `pong`; any
//     other frame for which NewReqMessage returns a message is dispatched via
//     ReceiveMessage;
//   - on return the session is marked as lost, the close message is
//     dispatched, the close hook runs, and finally the socket is closed.
//
// Incoming frames are capped at maxMessageSize and throttled to maxMessageRate
// per second. The read deadline is pushed forward by pongWait on every frame,
// so a connection that stays silent longer than that is dropped.
func (s *server) HandleWebsocket(r *ghttp.Request) {
	ws, err := r.WebSocket()
	if err != nil {
		// Log before r.Exit: ghttp's Request.Exit aborts the handler (it
		// panics internally) and never returns, so nothing after it runs.
		g.Log().Debug(`ws打开连接`, err)
		r.Exit()
		return
	}

	// Registered first so it runs last (after the teardown below), and so the
	// upgraded socket is closed even if a hook panics.
	defer func() {
		if err := ws.Close(); err != nil {
			g.Log().Debug(`ws断开连接`, err)
		}
	}()

	// Limit frame size and arm the idle timeout.
	ws.SetReadLimit(maxMessageSize)
	// Error ignored: a connection that cannot take a deadline is assumed to
	// also fail the next ReadMessage, which ends the loop below.
	_ = ws.SetReadDeadline(time.Now().Add(pongWait))

	session := NewSession(r, ws)

	// Teardown is registered before the open hook. s.openHandle is called
	// directly (not under ReceiveMessage's recover), and this way a panic in
	// it still runs the close path.
	defer func() {
		session.LostConnection()
		closeMsg := NewCloseMessage(r, session)
		s.ReceiveMessage(session, closeMsg)
		s.closeHandle(session, closeMsg)
	}()

	// Connection opened: the hook runs first, then the registered handlers.
	openMsg := NewOpenMessage(r, session)
	s.openHandle(session, openMsg)
	s.ReceiveMessage(session, openMsg)

	// Per-connection message rate limit.
	rateLimit := ratelimit.New(maxMessageRate)

	for {
		rateLimit.Take()

		_, msg, err := ws.ReadMessage()
		if err != nil {
			break
		}

		// Every received frame extends the idle deadline.
		_ = ws.SetReadDeadline(time.Now().Add(pongWait))

		if string(msg) == `ping` {
			session.SendTextMessage([]byte(`pong`))
			continue
		}
		// Frames NewReqMessage does not turn into a message are dropped.
		if message, _ := NewReqMessage(msg); message != nil {
			s.ReceiveMessage(session, message)
		}
	}
}

// ReceiveMessage dispatches message to every handler registered for its
// opcode and waits until all of them have returned.
//
// Each handler runs in its own goroutine. A handler is invoked only when
// session.Level is at least the handler's SessionLevel. Otherwise the client
// is notified instead (login required, or insufficient level):
//   - handlers requiring SessionLevelUser or higher produce OpSSessionLevelLoW,
//     carrying the required level;
//   - all others produce OpSNeedLogin.
//
// A panic inside a handler is recovered. An *Error panic value is sent back to
// the client through SendMessageFromError; any other value is logged as an
// error.
func (s *server) ReceiveMessage(session *Session, message *MessageReq) {
	handlers, registered := s.msgHandlers[message.Opcode]
	if !registered {
		return
	}

	var pending sync.WaitGroup
	pending.Add(len(handlers))
	for _, h := range handlers {
		go func(h *MessageHandler) {
			defer pending.Done()
			defer func() {
				rec := recover()
				if rec == nil {
					return
				}
				if appErr, isAppErr := rec.(*Error); isAppErr {
					session.SendMessageFromError(appErr)
					return
				}
				g.Log().Error(rec)
			}()

			// Only the session level is checked here; finer identity checks
			// are left to the individual handlers.
			switch {
			case session.Level >= h.SessionLevel:
				h.Func(session, message)
			case h.SessionLevel >= SessionLevelUser:
				session.SendMessageFromSys(OpSSessionLevelLoW, g.Map{`level`: h.SessionLevel})
			default:
				session.SendMessageFromSys(OpSNeedLogin)
			}
		}(h)
	}
	pending.Wait()
}

// AddHandler registers handle for opcode. ReceiveMessage runs it only for
// sessions whose level is at least level. Several handlers may share an
// opcode; they are stored in registration order.
//
// msgHandlers is not locked here, so this is presumably meant to be called
// during setup, before connections are served — confirm with callers.
func (s *server) AddHandler(level SessionLevel, opcode MessageOpcode, handle MessageHandleFunc) {
	entry := &MessageHandler{
		SessionLevel: level,
		Func:         handle,
	}
	// append on a missing key works on the nil slice the map lookup yields.
	s.msgHandlers[opcode] = append(s.msgHandlers[opcode], entry)
}

// AddSession registers session under its Id, replacing any session already
// stored with the same Id.
func (s *server) AddSession(session *Session) {
	mu := s.sessionLocker
	mu.Lock()
	defer mu.Unlock()
	s.sessions[session.Id] = session
}

// DelSession removes the session stored under session.Id. Removing an Id that
// is not registered is a no-op.
func (s *server) DelSession(session *Session) {
	mu := s.sessionLocker
	mu.Lock()
	defer mu.Unlock()
	delete(s.sessions, session.Id)
}

// GetSessions returns a snapshot of all registered sessions, keyed by session
// Id.
//
// A copy is returned because the internal map is guarded by sessionLocker. If
// it were handed out, callers would read it after the lock is released, racing
// with AddSession and DelSession. The returned map is the caller's to iterate
// or modify; the *Session values themselves are shared, not copied.
func (s *server) GetSessions() map[uint64]*Session {
	s.sessionLocker.RLock()
	defer s.sessionLocker.RUnlock()

	snapshot := make(map[uint64]*Session, len(s.sessions))
	for id, session := range s.sessions {
		snapshot[id] = session
	}
	return snapshot
}

// GetSession returns the session registered under id, or nil when there is
// none.
func (s *server) GetSession(id uint64) *Session {
	s.sessionLocker.RLock()
	found := s.sessions[id] // a missing key yields nil
	s.sessionLocker.RUnlock()

	return found
}

// AddUser registers user under its Id, replacing any user already stored with
// the same Id.
func (s *server) AddUser(user *User) {
	mu := s.userLocker
	mu.Lock()
	defer mu.Unlock()
	s.users[user.Id] = user
}

// DelUser removes the user stored under user.Id. Removing an Id that is not
// registered is a no-op.
func (s *server) DelUser(user *User) {
	mu := s.userLocker
	mu.Lock()
	defer mu.Unlock()
	delete(s.users, user.Id)
}

// GetUsers returns a snapshot of all registered users, keyed by user Id.
//
// A copy is returned because the internal map is guarded by userLocker. If it
// were handed out, callers would read it after the lock is released, racing
// with AddUser and DelUser. The returned map is the caller's to iterate or
// modify; the *User values themselves are shared, not copied.
func (s *server) GetUsers() map[string]*User {
	s.userLocker.RLock()
	defer s.userLocker.RUnlock()

	snapshot := make(map[string]*User, len(s.users))
	for id, user := range s.users {
		snapshot[id] = user
	}
	return snapshot
}

// GetUser returns the user registered under id, or nil when there is none.
func (s *server) GetUser(id string) *User {
	s.userLocker.RLock()
	found := s.users[id] // a missing key yields nil
	s.userLocker.RUnlock()

	return found
}
